{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualizations for the paper.\n",
    "We visualized a bunch of attention shifts in the paper. This notebook guides you through how to create your own visualizations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from iterator import SmartIterator\n",
    "from utils.visualization_utils import get_att_map, objdict, get_dict, add_attention, add_bboxes, get_bbox_from_heatmap, add_bbox_to_image\n",
    "from keras.models import load_model\n",
    "from old_models import ReferringRelationshipsModel\n",
    "from keras.utils import to_categorical\n",
    "import numpy as np\n",
    "import os\n",
    "from PIL import Image\n",
    "import json\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import h5py\n",
    "from keras.models import Model\n",
    "import seaborn as sns\n",
    "from scipy.misc import imresize\n",
    "\n",
    "matplotlib.rcParams.update({'font.size': 34})\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose the dataset you want to visualize some results for."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###################\n",
    "data_type = \"clevr\"\n",
    "###################\n",
    "if data_type==\"vrd\":\n",
    "    #annotations_file = \"/data/chami/ReferringRelationships/data/VRD/vrd_rels_multiple_objects.json\"\n",
    "    annotations_file = \"data/VRD/annotations_test.json\"\n",
    "    img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/'\n",
    "    vocab_dir = os.path.join('data/VRD')\n",
    "    model_checkpoint = \"/data/chami/tmp/model19-2.09.h5\"\n",
    "    #model_checkpoint =\"/data/chami/ReferringRelationships/models/VRD/11_09_2017/baseline_no_predicate/1/model26-1.68.h5\"\n",
    "    #model_checkpoint= \"/data/ranjaykrishna/ReferringRelationships/temp/vrd_iter1/model29-1.41.h5\"\n",
    "    #model_checkpoint = \"/data/chami/ReferringRelationships/models/VRD/11_09_2017/baseline/1/model21-1.59.h5\"\n",
    "    #model_checkpoint= \"/data/chami/ReferringRelationships/models/VRD/11_07_2017/baseline_no_predicate/8/model24-1.38.h5\"\n",
    "elif data_type==\"clevr\":\n",
    "    annotations_file = \"/data/chami/ReferringRelationships/data/Clevr/annotations_test.json\"\n",
    "    img_dir = '/data/ranjaykrishna/clevr/images/val/'\n",
    "    vocab_dir = os.path.join('/data/chami/ReferringRelationships/data/Clevr/')\n",
    "    model_checkpoint = \"/data/chami/ReferringRelationships/models/Clevr/11_13_2017/sym_ssn/1/model05-0.26.h5\"\n",
    "elif data_type==\"visualgenome\":\n",
    "    annotations_file = \"data/VisualGenome/annotations_test.json\"\n",
    "    img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/'\n",
    "    vocab_dir = os.path.join('data/VisualGenome')\n",
    "    model_checkpoint = \"/data/ranjaykrishna/ReferringRelationships/temp/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "annotations_test = json.load(open(annotations_file))\n",
    "predicate_dict, obj_subj_dict = get_dict(vocab_dir)\n",
    "image_ids = sorted(list(annotations_test.keys()))[:1000]\n",
    "params = objdict(json.load(open(os.path.join(os.path.dirname(model_checkpoint), \"args.json\"), \"r\")))\n",
    "relationships_model = ReferringRelationshipsModel(params)\n",
    "test_generator = SmartIterator(params.test_data_dir, params)\n",
    "images = test_generator.get_image_dataset()\n",
    "print(' | '.join(obj_subj_dict))\n",
    "print('')\n",
    "print(' | '.join(predicate_dict))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = relationships_model.build_model()\n",
    "model.load_weights(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_layers(name, offset=0):\n",
    "    attentions = []\n",
    "    while True:\n",
    "        layer_name = name.format(len(attentions) + offset)\n",
    "        try:\n",
    "            heatmap = model.get_layer(layer_name).output\n",
    "            attentions.append(heatmap)\n",
    "        except:\n",
    "            break\n",
    "    return attentions\n",
    "\n",
    "subject_attentions = get_layers(\"subject-att-{}\")\n",
    "object_attentions = get_layers(\"object-att-{}\")\n",
    "shift_attentions = get_layers(\"shift-{}\", offset=1)\n",
    "inv_shift_attentions = get_layers(\"inv-shift-{}\", offset=1)\n",
    "\n",
    "print(\"Found {} subject attentions\".format(len(subject_attentions)))\n",
    "print(\"Found {} object attentions\".format(len(object_attentions)))\n",
    "print(\"Found {} shift attentions\".format(len(shift_attentions)))\n",
    "print(\"Found {} inv shift attentions\".format(len(inv_shift_attentions)))\n",
    "\n",
    "all_attentions = subject_attentions + object_attentions + shift_attentions + inv_shift_attentions\n",
    "attention_model = Model(inputs=model.input, outputs=all_attentions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### USER INPUT - Pick an image "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#################\n",
    "image_index = np.random.randint(100)\n",
    "print(image_index)\n",
    "#################\n",
    "img = Image.open(os.path.join(img_dir, image_ids[image_index]))\n",
    "img = img.resize((params.input_dim, params.input_dim))\n",
    "plt.figure(figsize=(5,5))\n",
    "plt.imshow(img)\n",
    "plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### USER INPUT - Pick a relationship "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#################\n",
    "subj = \"green_metal_cube\"\n",
    "predicate = \"left\"\n",
    "obj = \"green_metal_cube\"\n",
    "#################\n",
    "subj_id = np.zeros((1, 1))\n",
    "predicate_id = np.zeros((1, params.num_predicates))\n",
    "#predicate_id = np.zeros((1, 1))\n",
    "obj_id = np.zeros((1, 1))\n",
    "relationship = [subj, predicate, obj]\n",
    "subj_id[0, 0] = obj_subj_dict.index(subj)\n",
    "predicate_id[0, predicate_dict.index(predicate)] = 1\n",
    "#predicate_id[0, 0] = predicate_dict.index(predicate)\n",
    "obj_id[0, 0] = obj_subj_dict.index(obj)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the model and visualize the heatmaps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_heatmaps = model.predict([images[image_index:image_index+1], subj_id, predicate_id, obj_id])\n",
    "#all_heatmaps = model.predict([images[image_index:image_index+1], subj_id, obj_id])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(15, 15))\n",
    "thresh = 0.9\n",
    "ax = plt.subplot2grid((1, 3), (0, 0), colspan=1, rowspan=1)\n",
    "ax.imshow(img)\n",
    "ax.axis(\"off\")\n",
    "ax = plt.subplot2grid((1, 3), (0, 1), colspan=1, rowspan=1)\n",
    "ax.imshow((all_heatmaps[0][0]>thresh).reshape((params.output_dim, params.output_dim)), interpolation='spline16')\n",
    "ax.axis(\"off\")\n",
    "ax = plt.subplot2grid((1, 3), (0, 2), colspan=1, rowspan=1)\n",
    "ax.imshow((all_heatmaps[1][0]>thresh).reshape((params.output_dim, params.output_dim)), interpolation='spline16')\n",
    "ax.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "att_map = get_att_map(img, np.maximum(all_heatmaps[0],0), np.maximum(all_heatmaps[1],0), params.feat_map_dim, relationship)\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.imshow(att_map)\n",
    "plt.title(\"Final Heatmaps: \" + \"-\".join(relationship))\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get all the predictions we want to visualize\n",
    "all_heatmaps = attention_model.predict([images[image_index:image_index+1], subj_id, predicate_id, obj_id])\n",
    "all_heatmaps = attention_model.predict([images[image_index:image_index+1], subj_id, predicate_id, obj_id])\n",
    "subject_heatmaps = all_heatmaps[0:len(subject_attentions)]\n",
    "object_heatmaps = all_heatmaps[len(subject_attentions):len(subject_attentions)+len(object_attentions)]\n",
    "shift_heatmaps = all_heatmaps[len(subject_attentions)+len(object_attentions):len(subject_attentions)+len(object_attentions)+len(shift_attentions)]\n",
    "inv_shift_heatmaps = all_heatmaps[len(subject_attentions)+len(object_attentions)+len(shift_attentions):len(subject_attentions)+len(object_attentions)+len(shift_attentions)+len(inv_shift_attentions)]\n",
    "final_subject_heatmap, final_object_heatmap = model.predict([images[image_index:image_index+1], subj_id, predicate_id, obj_id])\n",
    "\n",
    "# Visualize heatmaps\n",
    "att_map = get_att_map(img, np.maximum(final_subject_heatmap[0],0), np.maximum(final_object_heatmap[0],0), params.feat_map_dim, relationship)\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.imshow(att_map)\n",
    "plt.title(\"Final Heatmaps: \" + \"-\".join(relationship))\n",
    "plt.axis(\"off\")\n",
    "plt.show()\n",
    "\n",
    "# Visualize bounding boxes.\n",
    "att_map = add_bboxes(img, np.maximum(final_subject_heatmap[0],0), \n",
    "                     np.maximum(final_object_heatmap[0],0), 14, threshold=0.8)\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.imshow(att_map)\n",
    "plt.title(\"Final Bounding boxes: \" + \"-\".join(relationship))\n",
    "plt.axis(\"off\")\n",
    "plt.show()\n",
    "\n",
    "# Visualize attention over iterations.\n",
    "def display_attention_heatmaps(heatmaps, title):\n",
    "    fig, axs = plt.subplots(nrows=1, ncols=len(heatmaps), figsize=(40,5))\n",
    "    for idx in range(len(heatmaps)):\n",
    "        ax = axs[idx]\n",
    "        #att = add_attention(img, np.maximum(heatmaps[idx],0), params.input_dim)\n",
    "        att = heatmaps[idx].reshape((params.feat_map_dim, params.feat_map_dim))\n",
    "        att[att < np.max(params.heatmap_threshold)] = 0\n",
    "        ax.imshow(att, interpolation='spline16')\n",
    "        #sns.heatmap(att, annot=True, linewidths=.5, ax=ax)\n",
    "        ax.set_title(title + ' iteration-{}'.format(idx))\n",
    "        ax.axis(\"off\")\n",
    "    plt.show()\n",
    "\n",
    "display_attention_heatmaps(subject_heatmaps, 'subject attentions')\n",
    "display_attention_heatmaps(object_heatmaps, 'object attentions')\n",
    "display_attention_heatmaps(shift_heatmaps, 'shift heatmaps')\n",
    "display_attention_heatmaps(inv_shift_heatmaps, 'inv shift heatmaps')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualizations for the paper.\n",
    "############################\n",
    "iterations_to_show = 3\n",
    "threshold = 0.3\n",
    "#############################\n",
    "ncols = iterations_to_show*2 + 4\n",
    "nrows = 4\n",
    "fig = plt.figure(figsize=(14, 6))\n",
    "\n",
    "ax = plt.subplot2grid((nrows, ncols), (1, 0), colspan=2, rowspan=2)\n",
    "ax.imshow(img)\n",
    "ax.axis(\"off\")\n",
    "\n",
    "s_bbox = get_bbox_from_heatmap(final_subject_heatmap[0], threshold=threshold, input_dim=14)\n",
    "s_image = add_bbox_to_image(img, s_bbox, color='blue', width=3)\n",
    "ax = plt.subplot2grid((nrows, ncols), (0, iterations_to_show*2+2), colspan=2, rowspan=2)\n",
    "ax.imshow(s_image)\n",
    "ax.axis(\"off\")\n",
    "\n",
    "o_bbox = get_bbox_from_heatmap(final_object_heatmap[0], threshold=threshold, input_dim=14)\n",
    "o_image = add_bbox_to_image(img, o_bbox, color='green', width=3)\n",
    "ax = plt.subplot2grid((nrows, ncols), (2, iterations_to_show*2+2), colspan=2, rowspan=2)\n",
    "ax.imshow(o_image)\n",
    "ax.axis(\"off\")\n",
    "\n",
    "for iteration in range(iterations_to_show):\n",
    "    s_att = subject_heatmaps[iteration].reshape((params.feat_map_dim, params.feat_map_dim))\n",
    "    s_att[s_att < np.max(params.heatmap_threshold)] = 0\n",
    "    ax = plt.subplot2grid((nrows, ncols), (0, 2*iteration+2), colspan=2, rowspan=2)\n",
    "    ax.imshow(s_att, interpolation='spline16')\n",
    "    ax.axis(\"off\")\n",
    "                          \n",
    "    o_att = object_heatmaps[iteration].reshape((params.feat_map_dim, params.feat_map_dim))\n",
    "    o_att[o_att < np.max(params.heatmap_threshold)] = 0\n",
    "    ax = plt.subplot2grid((nrows, ncols), (2, 2*iteration+2), colspan=2, rowspan=2)\n",
    "    ax.imshow(o_att, interpolation='spline16')\n",
    "    ax.axis(\"off\")\n",
    "    \n",
    "plt.tight_layout(pad=0.1, w_pad=-1, h_pad=-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output = model.get_layer(\"subject-att-1\").output\n",
    "before_pred = Model(inputs=model.input, outputs=output)\n",
    "output = model.get_layer(\"object-att-1\").output\n",
    "after_pred = Model(inputs=model.input, outputs=output)\n",
    "output = model.get_layer(\"shift-1\").output\n",
    "shift = Model(inputs=model.input, outputs=output)\n",
    "interp_method = 'gaussian'\n",
    "map_1 = before_pred.predict([images[image_index:image_index+1], subj_id, predicate_id, obj_id])\n",
    "map_2 = after_pred.predict([images[image_index:image_index+1], subj_id, predicate_id, obj_id])\n",
    "map_3 = shift.predict([images[image_index:image_index+1], subj_id, predicate_id, obj_id])\n",
    "fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
    "img_1 = map_1.reshape((params.feat_map_dim, params.feat_map_dim))\n",
    "plot0 = axes[0].imshow(img_1, interpolation=interp_method)\n",
    "plot1 = axes[1].imshow(map_2.reshape((params.feat_map_dim, params.feat_map_dim)), interpolation=interp_method)\n",
    "fig.colorbar(plot0, ax=axes[0])\n",
    "axes[0].axis(\"off\")\n",
    "axes[0].set_title(\"before {}\".format(predicate))\n",
    "axes[1].axis(\"off\")\n",
    "axes[1].set_title(\"after {}\".format(predicate))\n",
    "fig.colorbar(plot1, ax=axes[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "map_1 = before_pred.predict([images[image_index:image_index+1], subj_id, predicate_id, obj_id]).reshape((params.feat_map_dim, params.feat_map_dim))\n",
    "map_2 = after_pred.predict([images[image_index:image_index+1], subj_id, predicate_id, obj_id]).reshape((params.feat_map_dim, params.feat_map_dim))\n",
    "map_3 = shift.predict([images[image_index:image_index+1], subj_id, predicate_id, obj_id]).reshape((params.feat_map_dim, params.feat_map_dim))\n",
    "fig, axes = plt.subplots(1, 3, figsize=(30, 7))\n",
    "sns.heatmap(map_1, annot=True, linewidths=.5, ax=axes[0])\n",
    "sns.heatmap(map_3, annot=True, linewidths=.5, ax=axes[1])\n",
    "sns.heatmap(map_2, annot=True, linewidths=.5, ax=axes[2])\n",
    "for i in range(3):\n",
    "    axes[i].axis(\"off\")\n",
    "axes[0].set_title(\"before-pred\")\n",
    "axes[1].set_title(\"shfted {}\".format(predicate))\n",
    "axes[2].set_title(\"after-pred {}\".format(predicate))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
